package messages

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"github.com/go-kit/kit/log"
	"github.com/go-kit/kit/log/level"
	_ "github.com/go-sql-driver/mysql"
	"github.com/olivere/elastic/v7"
	"strings"
)

// Elasticsearch index names used by this package.
const (
	MessageIndex      = "message"       // chat messages
	AlarmMessageIndex = "alarm_message" // alarm (warning) messages
)

// ErrorNoConnection is what callers receive whenever an Elasticsearch request
// made by this package fails. The text means "network problem, please retry".
var ErrorNoConnection = errors.New("网络出现问题,请重试")

// FileRow is the scan target for one row of the `file` table. The sql.Null*
// wrappers let a scan succeed when a column holds NULL; fileRowToFile turns a
// FileRow into a File, using the zero value for NULL columns.
//
// NOTE(review): SearchFile does not select md5 or count, and fileRowToFile
// ignores MD5 and Count, so those two fields are never populated there.
type FileRow struct {
	Id                 sql.NullString
	Name               sql.NullString
	FirstSeenTimestamp sql.NullInt64
	LastSeenTimestamp  sql.NullInt64
	Size               sql.NullInt64
	Type               sql.NullString
	MD5                sql.NullString
	Format             sql.NullString
	Tags               sql.NullString // tags joined by single spaces (written by SetTag)
	Count              sql.NullInt64
	Path               sql.NullString
}

// MessageRepository is the data-access interface for chat and alarm messages.
// Search methods return raw Elasticsearch hits; decoding them is left to the
// caller.
type MessageRepository interface {
	// SearchMessage returns the total hit count and one page of messages
	// matching searchMessageOpt. A non-empty lastData.Id continues the listing
	// after lastData (search_after pagination).
	SearchMessage(ctx context.Context, searchMessageOpt messageSearchRequest, size int64, lastData Message) (int64, []*elastic.SearchHit, error)
	// SearchContext returns up to size messages adjacent to data: the ones
	// after it when isAfter is true, the ones before it otherwise.
	SearchContext(ctx context.Context, isAfter bool, size int64, data Message) ([]*elastic.SearchHit, error)
	// AlarmMessage returns the total hit count and one page of alarm messages
	// matching alarmMessageOpt. topicIsDefault selects messages that have
	// (true) or lack (false) a default topic. A non-empty lastData.Id
	// continues the listing after lastData.
	AlarmMessage(ctx context.Context, topicIsDefault bool, alarmMessageOpt alarmMessageRequest, size int64, lastData Message) (int64, []*elastic.SearchHit, error)
	// GetChatroomIdsByCategoryId returns the ids of the chatrooms that belong
	// to categoryId.
	GetChatroomIdsByCategoryId(ctx context.Context, categoryId string) ([]string, error)
}

// FileRepository is the data-access interface for indexed files: paging
// through them, replacing a file's tags, and listing the message documents
// that reference a file.
type FileRepository interface {
	// SearchFile returns the total number of matching files and one page of
	// them; fileDir is prepended to every returned Path.
	SearchFile(ctx context.Context, searchFileOpt searchFileRequest, fileDir string) (int64, []*File, error)
	// SetTag replaces the tags stored for file id.
	SetTag(ctx context.Context, id string, tags []string) ([]tagInfo, error)
	// Travel returns the message hits that reference file id.
	Travel(ctx context.Context, id string) ([]*elastic.SearchHit, error)
}

// fileRepository is the FileRepository returned by NewFileRepository. File
// rows live in MySQL (the `file` table); per-file message counts and message
// lookups come from the Elasticsearch message index.
type fileRepository struct {
	mysqlClient   *sql.DB
	elasticClient *elastic.Client // TODO: temporarily using the Elasticsearch v7 client
	logger        log.Logger
}

// messageRepository is the MessageRepository returned by NewMessageRepository.
// Message searches go to Elasticsearch; mysqlDB is used for chatroom lookups
// (GetChatroomIdsByCategoryId).
type messageRepository struct {
	elasticClient *elastic.Client // TODO: temporarily using the Elasticsearch v7 client
	mysqlDB       *sql.DB
	logger        log.Logger
}

// GetChatroomIdsByCategoryId returns the ids of all rows in the chatroom table
// whose categoryId equals categoryId. A category with no chatrooms yields an
// empty, non-nil slice and a nil error.
//
// Rows that fail to scan are logged and skipped (best effort). A failure of
// the query itself, or an error reported while iterating, is logged and
// surfaced to the caller as the generic "获取分类群失败" error.
func (r *messageRepository) GetChatroomIdsByCategoryId(ctx context.Context, categoryId string) ([]string, error) {
	// Query (unlike QueryRow) never reports sql.ErrNoRows: an empty category
	// simply produces zero rows.
	rows, err := r.mysqlDB.QueryContext(ctx, `SELECT id FROM chatroom WHERE categoryId=?;`, categoryId)
	if err != nil {
		_ = r.logger.Log("get chatroom by categoryId error:", err, "categoryId:", categoryId)
		return nil, errors.New("获取分类群失败")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			_ = r.logger.Log("sql error", err.Error())
		}
	}()

	chatroomIds := make([]string, 0)
	for rows.Next() {
		var chatroomId string
		if err := rows.Scan(&chatroomId); err != nil {
			// Skip the unreadable row but keep the rest of the result.
			_ = r.logger.Log("get chatroom by categoryId scan error:", err, "categoryId:", categoryId)
			continue
		}
		chatroomIds = append(chatroomIds, chatroomId)
	}
	// Next returns false both when rows are exhausted and when iteration
	// fails; only Err distinguishes the two.
	if err := rows.Err(); err != nil {
		_ = r.logger.Log("get chatroom by categoryId iteration error:", err, "categoryId:", categoryId)
		return nil, errors.New("获取分类群失败")
	}
	return chatroomIds, nil
}

// NewFileRepository wires a FileRepository to the given MySQL handle,
// Elasticsearch client and logger. The logger is used as-is. The returned
// error is always nil.
func NewFileRepository(mysqlClient *sql.DB, esClient *elastic.Client, logger log.Logger) (FileRepository, error) {
	repo := new(fileRepository)
	repo.mysqlClient = mysqlClient
	repo.elasticClient = esClient
	repo.logger = logger
	return repo, nil
}

// NewMessageRepository wires a MessageRepository to the given Elasticsearch
// client and SQL handle. Every log line it emits carries the key/value pair
// "rep"="es client". The returned error is always nil.
func NewMessageRepository(elasticClient *elastic.Client, sqlDB *sql.DB, logger log.Logger) (MessageRepository, error) {
	taggedLogger := log.With(logger, "rep", "es client")
	repo := &messageRepository{logger: taggedLogger, mysqlDB: sqlDB, elasticClient: elasticClient}
	return repo, nil
}

// SearchFile returns one page of rows from the file table matching
// searchFileOpt, together with the total number of matching rows.
//
// Optional filters, combined with AND:
//   - Name: substring match (name LIKE %Name%);
//   - Types: type IN (Types...);
//   - StartTimestamp/EndTimestamp, only when both are non-zero: files first
//     seen at or after StartTimestamp and last seen at or before EndTimestamp.
//
// Paging uses Size rows per page and a 1-based Page; a Page below 1 is treated
// as the first page instead of producing a negative LIMIT offset. Each returned
// file's Path is prefixed with fileDir, and its Count is the number of
// documents in the message index whose filePath equals the file's id.
//
// Rows that fail to scan are logged and skipped. SQL and Elasticsearch errors
// are returned unchanged.
//
// NOTE(review): the row query has no ORDER BY, so MySQL does not promise a
// stable row order across pages — consider ordering by id.
func (r *fileRepository) SearchFile(ctx context.Context, searchFileOpt searchFileRequest, fileDir string) (int64, []*File, error) {
	sqlRows := `SELECT
		id,
		name,
		firstSeenTimestamp,
		lastSeenTimestamp,
		size,
		type,
		format,
		tags,
		path
	FROM file`
	sqlCount := `SELECT COUNT(id) FROM file`

	var (
		conds []string
		args  []interface{}
	)
	if searchFileOpt.Name != "" {
		conds = append(conds, "name LIKE ?")
		args = append(args, "%"+searchFileOpt.Name+"%")
	}
	if len(searchFileOpt.Types) > 0 {
		placeholders := "?" + strings.Repeat(",?", len(searchFileOpt.Types)-1)
		conds = append(conds, fmt.Sprintf("type IN (%s)", placeholders))
		for _, t := range searchFileOpt.Types {
			args = append(args, t)
		}
	}
	if searchFileOpt.StartTimestamp != 0 && searchFileOpt.EndTimestamp != 0 {
		conds = append(conds, "firstSeenTimestamp >= ?", "lastSeenTimestamp <= ?")
		args = append(args, searchFileOpt.StartTimestamp, searchFileOpt.EndTimestamp)
	}
	if len(conds) > 0 {
		where := " WHERE " + strings.Join(conds, " AND ")
		sqlRows += where
		sqlCount += where
	}
	sqlRows += " LIMIT ?,?"

	var count int64
	if err := r.mysqlClient.QueryRowContext(ctx, sqlCount, args...).Scan(&count); err != nil {
		return 0, nil, err
	}

	offset := searchFileOpt.Size * (searchFileOpt.Page - 1)
	if offset < 0 {
		// Page 0 (or below) would otherwise send a negative offset to MySQL.
		offset = 0
	}
	rows, err := r.mysqlClient.QueryContext(ctx, sqlRows, append(args, offset, searchFileOpt.Size)...)
	if err != nil {
		return 0, nil, err
	}
	defer func() {
		_ = rows.Close()
	}()

	var (
		files   []*File
		fileIds []interface{}
	)
	for rows.Next() {
		var row FileRow
		if err := rows.Scan(
			&row.Id,
			&row.Name,
			&row.FirstSeenTimestamp,
			&row.LastSeenTimestamp,
			&row.Size,
			&row.Type,
			&row.Format,
			&row.Tags,
			&row.Path,
		); err != nil {
			_ = r.logger.Log("row Scan Error", err)
			continue
		}
		file, err := fileRowToFile(row)
		if err != nil {
			_ = r.logger.Log("row convert error", err)
			continue
		}
		file.Path = fileDir + file.Path
		files = append(files, &file)
		fileIds = append(fileIds, file.Id)
	}
	// Next returns false on both exhaustion and failure; Err tells them apart.
	if err := rows.Err(); err != nil {
		return 0, nil, err
	}
	if len(files) == 0 {
		return count, files, nil
	}

	// Count, per file on this page, the messages whose filePath is the file id.
	// Only the aggregation is needed, so no hits are fetched (Size(0)).
	agg := elastic.NewTermsAggregation().Field("filePath").Size(len(files))
	result, err := r.elasticClient.
		Search(MessageIndex).
		Query(elastic.NewTermsQuery("filePath", fileIds...)).
		Aggregation("file", agg).
		Size(0).
		Do(ctx)
	if err != nil {
		return 0, nil, err
	}
	termsResult, found := result.Aggregations.Terms("file")
	if !found {
		return 0, nil, errors.New("not found")
	}
	fileCountMap := make(map[string]int64, len(files))
	for _, bucket := range termsResult.Buckets {
		// Comma-ok guards against a non-string key (e.g. an unexpected mapping)
		// instead of panicking.
		if key, ok := bucket.Key.(string); ok {
			fileCountMap[key] = bucket.DocCount
		}
	}
	for _, f := range files {
		f.Count = fileCountMap[f.Id]
	}
	return count, files, nil
}

// SetTag overwrites the tags column of file id with tags joined by single
// spaces, and echoes the request back as a one-element []tagInfo. A database
// failure is logged and reported as "更新标签失败". The Exec result is not
// inspected, so an id that matches no row is not reported as an error.
//
// NOTE(review): a tag that itself contains a space will be split apart when
// the row is read back by fileRowToFile.
func (r *fileRepository) SetTag(ctx context.Context, id string, tags []string) ([]tagInfo, error) {
	joined := strings.Join(tags, " ")
	if _, err := r.mysqlClient.ExecContext(ctx, `UPDATE file SET tags = ? WHERE id = ?`, joined, id); err != nil {
		_ = level.Error(r.logger).Log("update fail", err)
		return nil, errors.New("更新标签失败")
	}
	echo := setTagRequest{Id: id, Tags: tags}
	return []tagInfo{{echo}}, nil
}

// Travel returns the message-index hits whose filePath equals id, newest
// timestamp first. The request is bound to ctx, so caller cancellation and
// deadlines apply. Any Elasticsearch error is logged and reported as
// ErrorNoConnection.
//
// NOTE(review): no Size is set, so Elasticsearch's default page size applies
// (presumably 10 hits) — confirm a single default page is intended.
func (r *fileRepository) Travel(ctx context.Context, id string) ([]*elastic.SearchHit, error) {
	resp, err := r.elasticClient.
		Search(MessageIndex).
		Sort("timestamp", false).
		Query(elastic.NewTermQuery("filePath", id)).
		Do(ctx)
	if err != nil {
		_ = level.Error(r.logger).Log("elasticsearch connection error:", err)
		return nil, ErrorNoConnection
	}
	return resp.Hits.Hits, nil
}

// SearchContext returns up to size messages adjacent to data, using
// search_after on the (timestamp, id) sort:
//   - isAfter == true: messages after data, in ascending (timestamp, id) order;
//   - isAfter == false: messages before data, in descending order.
//
// When data.Chatroom.Id is non-empty the search is restricted to that
// chatroom; otherwise no chatroom filter is applied. Any Elasticsearch error is
// logged and reported as ErrorNoConnection.
func (r *messageRepository) SearchContext(ctx context.Context, isAfter bool, size int64, data Message) ([]*elastic.SearchHit, error) {
	topQuery := elastic.NewBoolQuery()
	// Add the filter only when there is a chatroom id. Passing a nil
	// *TermQuery to Filter stores a non-nil Query interface wrapping a nil
	// pointer, and serializing the request then dereferences nil.
	if chatroomID := data.Chatroom.Id; len(chatroomID) > 0 {
		topQuery.Filter(elastic.NewTermQuery("chatroom.id", chatroomID))
	}
	resp, err := r.elasticClient.
		Search(MessageIndex).
		Size(int(size)).
		Sort("timestamp", isAfter).
		Sort("id", isAfter).
		SearchAfter(data.Timestamp, data.Id).
		Query(topQuery).
		Do(ctx)
	if err != nil {
		_ = level.Error(r.logger).Log("elasticsearch connection error:", err)
		return nil, ErrorNoConnection
	}
	return resp.Hits.Hits, nil
}

// AlarmMessage fetches one page of alarm messages matching alarmMessageOpt
// from AlarmMessageIndex and returns the total hit count with the page of hits.
//
// topicIsDefault keeps only messages that have (true) or lack (false) a nested
// topic whose isDefault is true.
//
// Sorting follows alarmMessageOpt.OrderBy (empty means "timestamp"):
//   - "timestamp": by timestamp;
//   - "warningScore": by warningScore, except that when a Topic filter is set
//     it sorts by that topic's nested topics.score instead;
//   - any other value: by nested topics.score.
//
// id breaks ties. The order is descending unless alarmMessageOpt.Skip is
// negative, in which case it is ascending. When lastData.Id is non-empty the
// page continues after lastData via search_after, using lastData's value for
// the active sort field. Any Elasticsearch error is logged and reported as
// ErrorNoConnection.
func (r *messageRepository) AlarmMessage(ctx context.Context, topicIsDefault bool, alarmMessageOpt alarmMessageRequest, size int64, lastData Message) (int64, []*elastic.SearchHit, error) {
	topQuery := elastic.NewBoolQuery()

	// Default-topic selection.
	isDefaultQuery := elastic.NewNestedQuery("topics", elastic.NewTermQuery("topics.isDefault", true))
	if topicIsDefault {
		topQuery.Must(isDefaultQuery)
	} else {
		topQuery.MustNot(isDefaultQuery)
	}

	orderBy := "timestamp"
	if alarmMessageOpt.OrderBy != "" {
		orderBy = alarmMessageOpt.OrderBy
	}

	// Time range: applied only when both bounds are set.
	if alarmMessageOpt.StartTimestamp > 0 && alarmMessageOpt.EndTimestamp > 0 {
		topQuery.Filter(elastic.NewRangeQuery("timestamp").
			Gte(alarmMessageOpt.StartTimestamp).
			Lte(alarmMessageOpt.EndTimestamp))
	}
	// Data source.
	if len(alarmMessageOpt.DataSource) > 0 {
		topQuery.Must(elastic.NewTermQuery("dataSource", alarmMessageOpt.DataSource))
	}
	// Alarm topic. Within a selected topic, "warningScore" ordering is switched
	// to the nested per-topic score.
	if len(alarmMessageOpt.Topic) > 0 {
		topQuery.Must(elastic.NewNestedQuery("topics", elastic.NewTermQuery("topics.id", alarmMessageOpt.Topic)))
		if orderBy == "warningScore" {
			orderBy = "topics"
		}
	}
	// Keyword (phrase match on content).
	if len(alarmMessageOpt.Keyword) > 0 {
		topQuery.Must(elastic.NewMatchPhraseQuery("content", alarmMessageOpt.Keyword))
	}
	// Collecting organization.
	if len(alarmMessageOpt.CollectorOrgan) > 0 {
		topQuery.Must(elastic.NewTermQuery("collectorOrgan", alarmMessageOpt.CollectorOrgan))
	}
	// Collector account.
	if len(alarmMessageOpt.CollectorId) > 0 {
		topQuery.Must(elastic.NewTermQuery("collectorId", alarmMessageOpt.CollectorId))
	}
	// TODO: user permission filtering.

	// A negative Skip flips the order to ascending (presumably to page backwards).
	sortAsc := alarmMessageOpt.Skip < 0

	var (
		sortQuery *elastic.FieldSort
		afterKey  interface{} // lastData's value for the primary sort field
	)
	switch orderBy {
	case "timestamp":
		sortQuery = elastic.NewFieldSort("timestamp")
		afterKey = lastData.Timestamp
	case "warningScore":
		sortQuery = elastic.NewFieldSort("warningScore")
		afterKey = lastData.WarningScore
	default:
		nested := elastic.NewNestedSort("topics")
		if len(alarmMessageOpt.Topic) > 0 {
			// Sort by the selected topic's score only. Without this filter ES
			// sorts by the max/min score over all of a message's topics, which
			// disagrees with the per-topic search_after value computed below.
			nested = nested.Filter(elastic.NewTermQuery("topics.id", alarmMessageOpt.Topic))
		}
		sortQuery = elastic.NewFieldSort("topics.score").Nested(nested)
		// NOTE(review): the query and sort filters match the topic by id, but
		// this lookup matches by Name — confirm which one Topic carries.
		var score float64
		for _, t := range lastData.Topics {
			if t.Name == alarmMessageOpt.Topic {
				score = t.Score
			}
		}
		afterKey = score
	}
	if sortAsc {
		sortQuery = sortQuery.Asc()
	} else {
		sortQuery = sortQuery.Desc()
	}

	search := r.elasticClient.
		Search(AlarmMessageIndex).
		Size(int(size)).
		SortBy(sortQuery).
		Sort("id", sortAsc).
		Query(topQuery)
	if len(lastData.Id) > 0 {
		search = search.SearchAfter(afterKey, lastData.Id)
	}
	resp, err := search.Do(ctx)
	if err != nil {
		_ = level.Error(r.logger).Log("elasticsearch connection error:", err)
		return 0, nil, ErrorNoConnection
	}
	return resp.TotalHits(), resp.Hits.Hits, nil
}

// SearchMessage fetches one page of text messages matching searchMessageOpt
// from MessageIndex and returns the total hit count with the page of hits.
//
// Results are sorted by timestamp, or by relevance (_score) when a Keyword is
// given, with id as tie-breaker. When lastData.Id is non-empty the page
// continues after lastData via search_after, and a negative
// searchMessageOpt.Skip flips the order to ascending. Without lastData the
// order is always descending. Any Elasticsearch error is logged and reported
// as ErrorNoConnection.
func (r *messageRepository) SearchMessage(ctx context.Context, searchMessageOpt messageSearchRequest, size int64, lastData Message) (int64, []*elastic.SearchHit, error) {
	topQuery := elastic.NewBoolQuery()

	// Time range: applied only when both bounds are set.
	if searchMessageOpt.StartTimestamp > 0 && searchMessageOpt.EndTimestamp > 0 {
		topQuery.Filter(elastic.NewRangeQuery("timestamp").
			Gte(searchMessageOpt.StartTimestamp).
			Lte(searchMessageOpt.EndTimestamp))
	}
	// Only text messages are searched.
	topQuery.Filter(elastic.NewTermQuery("type", "text"))
	// Data source.
	if len(searchMessageOpt.DataSource) > 0 {
		topQuery.Must(elastic.NewTermQuery("dataSource", searchMessageOpt.DataSource))
	}
	// Keyword (phrase match on content).
	if len(searchMessageOpt.Keyword) > 0 {
		topQuery.Must(elastic.NewMatchPhraseQuery("content", searchMessageOpt.Keyword))
	}
	// Collecting organization.
	if len(searchMessageOpt.CollectorOrgan) > 0 {
		topQuery.Must(elastic.NewTermQuery("collectorOrgan", searchMessageOpt.CollectorOrgan))
	}
	// Collector.
	if len(searchMessageOpt.CollectorId) > 0 {
		topQuery.Must(elastic.NewTermQuery("collectorId", searchMessageOpt.CollectorId))
	}
	// Chatroom: matches either the chatroom name (phrase) or its exact id.
	if len(searchMessageOpt.Chatroom) > 0 {
		topQuery.Must(elastic.NewBoolQuery().Should(
			elastic.NewMatchPhraseQuery("chatroom.name", searchMessageOpt.Chatroom),
			elastic.NewTermQuery("chatroom.id", searchMessageOpt.Chatroom),
		))
	}
	// Member: matches either the member nickname (phrase) or its exact id.
	if len(searchMessageOpt.Member) > 0 {
		topQuery.Must(elastic.NewBoolQuery().Should(
			elastic.NewMatchPhraseQuery("member.nickname", searchMessageOpt.Member),
			elastic.NewTermQuery("member.id", searchMessageOpt.Member),
		))
	}
	// Chatroom category, expressed as an explicit chatroom id list (presumably
	// resolved by the caller via GetChatroomIdsByCategoryId).
	if len(searchMessageOpt.ChatroomIds) > 0 {
		ids := make([]interface{}, len(searchMessageOpt.ChatroomIds))
		for i, id := range searchMessageOpt.ChatroomIds {
			ids[i] = id
		}
		topQuery.Filter(elastic.NewTermsQuery("chatroom.id", ids...))
	}
	// TODO: user permission filtering.

	sortField := "timestamp"
	var afterKey interface{} = lastData.Timestamp
	if len(searchMessageOpt.Keyword) > 0 {
		// Keyword searches are ordered by relevance.
		sortField = "_score"
		afterKey = lastData.Score
	}
	continuing := len(lastData.Id) > 0
	// Skip only affects direction when continuing from lastData.
	sortAsc := continuing && searchMessageOpt.Skip < 0

	search := r.elasticClient.
		Search(MessageIndex).
		Size(int(size)).
		Sort(sortField, sortAsc).
		Sort("id", sortAsc).
		Query(topQuery)
	if continuing {
		search = search.SearchAfter(afterKey, lastData.Id)
	}
	resp, err := search.Do(ctx)
	if err != nil {
		_ = level.Error(r.logger).Log("elasticsearch connection error:", err)
		return 0, nil, ErrorNoConnection
	}
	return resp.TotalHits(), resp.Hits.Hits, nil
}

// fileRowToFile converts a scanned FileRow into a File. NULL columns become
// zero values. The tags column is split on single spaces; a NULL or empty tags
// column yields an empty, non-nil Tags slice. MD5 and Count are not copied.
// The error result is currently always nil.
func fileRowToFile(row FileRow) (File, error) {
	tags := []string{}
	if row.Tags.Valid && len(row.Tags.String) > 0 {
		tags = strings.Split(row.Tags.String, " ")
	}
	return File{
		Id:                 row.Id.String,
		Name:               row.Name.String,
		FirstSeenTimestamp: row.FirstSeenTimestamp.Int64,
		LastSeenTimestamp:  row.LastSeenTimestamp.Int64,
		Size:               row.Size.Int64,
		Type:               row.Type.String,
		Format:             row.Format.String,
		Tags:               tags,
		Path:               row.Path.String,
	}, nil
}
